# Funk, Paul and Philips
# "Point Break: Using Machine Learning to Uncover a Critical Mass in Women’s Representation", Forthcoming at PSRM
#
#
# Last Updated: 7/19/21
# Files created:
# -Figure 3c: "VIPfull_defense.pdf"
# -Figure 3a: "VIPfull_edu.pdf"
# -Figure 3b: "VIPfull_health.pdf"
# -Figure 4: "pctwomen_3expenditures_PDP.pdf"
# -Figure 5a: "pctwomen_democracyinteraction_edu.pdf"
# -Figure 5b: "pctwomen_democracyinteraction_health.pdf"
# -Figure 5c: "pctwomen_democracyinteraction_mil.pdf"
# -Figure 6a: "pctwomen_yearinteraction_edu.pdf"
# -Figure 6b: "pctwomen_yearinteraction_health.pdf"
# -Figure 6c: "pctwomen_yearinteraction_mil.pdf"
# -Figure 7: "pctwomen_3expenditures_implementedquotainteraction.pdf"
# -SI Figure 1: "pctwomen_3expenditures_ICE.pdf"
# -SI Figure 2: "alternative_dependence_edu.pdf"
# -SI Figure 3: "alternative_dependence_health.pdf"
# -SI Figure 4: "alternative_dependence_mil.pdf"
# -SI Figure 8a: "alternative_estimators_edu.pdf"
# -SI Figure 8b: "alternative_estimators_health.pdf"
# -SI Figure 8c: "alternative_estimators_mil.pdf"
# ----------------------------------------------------------------------------
# Move into the replication folder when it exists on this machine. The path is
# specific to one Dropbox layout, so on any other machine the current working
# directory is kept (run the script from the folder that holds the .dta file)
# instead of aborting on a missing directory.
replication.dir <- "~/Dropbox/Kendall-Hannah-Andy/final-replication-june2021"
if (dir.exists(replication.dir)) {
  setwd(replication.dir)
}

# Fix the RNG state so that the random forests fitted below are reproducible.
set.seed(590381)

library(randomForest)
library(gbm)
library(pdp)
library(lattice)
library(caret)
library(ggplot2)
library(haven)
library(ranger)
library(viridis)
library(haven)
library(rdd)
library(ALEPlot)
library(scales)
library(lime)       
library(vip) 
library(DALEX)
library(dplyr)
library(mgcv) # for GAMs
library(devtools)
#devtools::install_github("AlexAfanasev/NeuralNetworkVisualization") # https://github.com/AlexAfanasev/NeuralNetworkVisualization
library(NeuralNetworkVisualization) # for visualizing Neural networks
library(neuralnet) # underlying neural net package used ^^
library(segmented) # for piecewise regression

# Read the replication data set (Stata format) from the working directory.
wdi.slim <- read_dta('funk-paul-philips-replication.dta')


# ----------- CREATE SMALLER DATASET OF ONLY NEEDED VARIABLES ------------------ #
# The variables are listed in thematic groups and concatenated in a fixed
# order: the three outcomes come first, followed by every predictor.

# Dependent variables.
outcome.vars <- c("percent_education", "percent_military", "percent_health")

# Key explanatory variables.
core.vars <- c(
  "percent_women_Q",
  "polity2" # democracy measure
)

# Quota variables.
quota.vars <- c(
  # implementedquota: dummy -- country has implemented a gender quota in an
  # election. Coded ‘1’ beginning in the year a quota has been implemented in
  # an election (whether or not the law was followed) and in all subsequent
  # years, unless the quota is overturned or withdrawn.
  "implementedquota",
  "defactothreshold",
  "quotastrength1",
  "quotastrength2",
  "quotashock"
)

# Identifier.
id.vars <- c("year")

# Control variables (descriptions as recorded with the data).
control.vars <- c(
  'Agedependencyratioofworki',    # age dependency ratio (ADR)
  'Agricultureforestryandfishi',  # agriculture, forestry, and fishing, value added (% of GDP)
  'Birthratecrudeper1000peo',     # birth rate, crude (per 1,000 people)
  'Employmenttopopulationratio',  # employment to population ratio, 15+, total (%)
  'Fertilityratetotalbirthspe',   # fertility rate, total (births per woman)
  'Foreigndirectinvestmentneti',  # foreign direct investment, net inflows (% of GDP)
  'GDPgrowthannualNYGDPMK',       # GDP growth (annual %)
  'GDPpercapitaconstant2010US',   # GDP per capita (constant 2010 US$)
  'Importsofgoodsandservices',    # imports of goods and services (% of GDP)
  'InflationGDPdeflatorannual',   # inflation, GDP deflator (annual %)
  'Laborforceparticipationrate',  # labor force participation rate, female (% of female population ages 15+)
  'Laborforcefemaleoftotal',      # labor force, female (% of total labor force)
  'P',                            # labor force participation rate, male (% of male population ages 15+)
  'Lifeexpectancyatbirthfemale',  # life expectancy at birth, female (years)
  'Lifeexpectancyatbirthmale',    # life expectancy at birth, male (years)
  'Lifetimeriskofmaternaldeath',  # lifetime risk of maternal death (%)
  'Maternalmortalityratiomodele', # maternal mortality ratio (modeled estimate, per 100,000 live births)
  'PopulationgrowthannualSP',     # population growth (annual %)
  'Populationdensitypeoplepers',  # population density (people per sq. km of land area)
  'PopulationtotalSPPOPTOTL',     # population, total
  'Populationfemaleoftotal',      # population, female (% of total)
  'Prevalenceofanemiaamongnonp',  # prevalence of anemia among non-pregnant women (% of women ages 15-49)
  'Ruralpopulationoftotalpop',    # rural population (% of total population)
  'Schoolenrollmentprimarygr',    # school enrollment (% gross); earlier note said "preprimary", name suggests primary -- TODO confirm
  'TradeofGDPNETRDGNFSZS',        # trade (% of GDP)
  'Unemploymenttotaloftotal',     # unemployment, total (% of total labor force)
  'Unemploymentmaleofmalela'      # unemployment, male (% of male labor force)
)

myvars <- c(outcome.vars, core.vars, quota.vars, id.vars, control.vars)

# Keep only the listed columns, then drop every row with a missing value.
wdi.slim <- na.omit(wdi.slim[myvars])
nrow(wdi.slim) # 1265 obs total

# turn dichotomous ones into factors:
wdi.slim[["implementedquota"]] <- as.factor(wdi.slim[["implementedquota"]])

# set hyperparameters
n.trees <- 300   # number of trees grown per forest
node.size <- 1   # minimum size of terminal nodes
mtry <- 14       # number of predictors sampled as candidates at each split

# ------------------------- RUN RANDOM FORESTS ON EACH DV ---------------------------------- 
# One regression forest per spending outcome. The other two outcomes are
# removed from the right-hand side of each formula so they never act as
# predictors.
#
# NOTE: randomForest() spells the terminal-node-size argument `nodesize`.
# Supplying it as `node.size` matches no formal argument; the value is absorbed
# by `...` and ignored, so the package default for regression (5) is used
# instead of `node.size`. With the correct spelling the setting above takes
# effect, which means fitted forests differ from runs that used the default.
rf.edu <- randomForest(
  percent_education ~ . - percent_health - percent_military,
  data = wdi.slim, mtry = mtry, nodesize = node.size,
  importance = TRUE, ntree = n.trees
)

rf.health <- randomForest(
  percent_health ~ . - percent_education - percent_military,
  data = wdi.slim, mtry = mtry, nodesize = node.size,
  importance = TRUE, ntree = n.trees
)

rf.mil <- randomForest(
  percent_military ~ . - percent_education - percent_health,
  data = wdi.slim, mtry = mtry, nodesize = node.size,
  importance = TRUE, ntree = n.trees
)
# ------------------------------------------------------------------------------------------ 


# ------------------------- CREATE VARIABLE IMPORTANCE PLOTS ---------------------------------- #
# Display label for every predictor, keyed by its column name in the data.
vip.labels <- c(
  Populationfemaleoftotal = "% Female Population",
  PopulationtotalSPPOPTOTL = "Total Population",
  Populationdensitypeoplepers = "Population Density",
  Ruralpopulationoftotalpop = "% Rural Population",
  Importsofgoodsandservices = "Imports",
  TradeofGDPNETRDGNFSZS = "Trade",
  Lifeexpectancyatbirthfemale = "Female Life Expectancy",
  Prevalenceofanemiaamongnonp = "Anemia Prevalence (% of Women)",
  percent_women_Q = "% Women in Legislature",
  Agedependencyratioofworki = "Age Dependency Ratio",
  P = "Male Labor Force Participation Rate",
  Laborforcefemaleoftotal = "% Labor Force Female",
  year = "Year",
  Schoolenrollmentprimarygr = "% School Enrollment",
  Unemploymenttotaloftotal = "Unemployment Rate",
  Lifeexpectancyatbirthmale = "Male Life Expectancy",
  Unemploymentmaleofmalela = "Male Unemployment (%)",
  Laborforceparticipationrate = "Female Labor Force Participation Rate",
  GDPpercapitaconstant2010US = "GDP Per Capita",
  Agricultureforestryandfishi = "% of GDP from Agriculture",
  PopulationgrowthannualSP = "Population Growth",
  GDPgrowthannualNYGDPMK = "GDP Growth",
  polity2 = "Polity",
  Employmenttopopulationratio = "Employment to Population Ratio",
  Fertilityratetotalbirthspe = "Fertility Rate",
  Foreigndirectinvestmentneti = "FDI",
  Birthratecrudeper1000peo = "Birth Rate",
  Maternalmortalityratiomodele = "Maternal Mortality Ratio",
  InflationGDPdeflatorannual = "Inflation",
  Lifetimeriskofmaternaldeath = "Maternal Mortality Risk (Lifetime)",
  quotastrength1 = "Quota Strength (1)",
  quotastrength2 = "Quota Strength (2)",
  implementedquota = "Implemented Quota",
  quotashock = "Quota Shock",
  defactothreshold = "De Facto Threshold"
)

# Build the importance table for one fitted forest.
#   rf: a fitted randomForest object with `importance` and `importanceSD`.
# Returns a data frame sorted from most to least important with columns
#   V1       predictor name as used in the model,
#   V2       first importance column divided by its standard deviation,
#   varnames display label (NA when a predictor has no entry in vip.labels).
make.vip.data <- function(rf) {
  raw.importance <- rf$importance[, 1]
  vip <- data.frame(
    V1 = names(raw.importance),
    V2 = as.numeric(raw.importance / rf$importanceSD),
    row.names = names(raw.importance),
    stringsAsFactors = FALSE
  )
  vip <- vip[order(-vip$V2), ] # sort from largest to smallest
  vip$varnames <- unname(vip.labels[vip$V1])
  vip
}

# Bar chart of the importance table, bars ordered by importance. The bar for
# % women in the legislature gets its own fill level so it stands out; the
# legend is hidden.
#   vip:   data frame returned by make.vip.data().
#   title: plot title.
make.vip.plot <- function(vip, title) {
  ggplot(
    vip,
    aes(
      x = reorder(varnames, -V2), y = V2,
      fill = factor(ifelse(V1 == 'percent_women_Q', 'Normal', 'Highlighted'))
    )
  ) +
    geom_bar(stat = 'identity') +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 75, hjust = 1)) +
    xlab('') +
    ylab('% Increase in MSE') +
    ggtitle(title) +
    theme(legend.position = "none")
}

# ggplot objects are only drawn when printed. Auto-printing does not happen
# when the script is run through source(), so each plot is printed explicitly
# while its PDF device is open; otherwise the file would be written empty.

# Plot VIPs, sorting by importance and highlighting % women
# For defense
vip.dat <- make.vip.data(rf.mil)
graph.full <- make.vip.plot(vip.dat, 'Defense')
pdf(file = 'VIPfull_defense.pdf', width = 8, height = 5.5)
print(graph.full)
dev.off()

# For education
vip.dat <- make.vip.data(rf.edu)
graph.full <- make.vip.plot(vip.dat, 'Education')
pdf(file = 'VIPfull_edu.pdf', width = 8, height = 5.5)
print(graph.full)
dev.off()

# For health
vip.dat <- make.vip.data(rf.health)
graph.full <- make.vip.plot(vip.dat, 'Health')
pdf(file = 'VIPfull_health.pdf', width = 8, height = 5.5)
print(graph.full)
dev.off()



# ------------------------- CREATE PARTIAL DEPENDENCE PLOTS ---------------------------------- #
set.seed(590381)

# Prediction wrapper handed to partial() as pred.fun for every outcome.
pred <- function(object, newdata) {
  predict(object, newdata)
}

# Predictions over the grid of % women values, one forest at a time.
pdp.edu <- partial(rf.edu, pred.var = 'percent_women_Q', pred.fun = pred)
pdp.health <- partial(rf.health, pred.var = 'percent_women_Q', pred.fun = pred)
pdp.mil <- partial(rf.mil, pred.var = 'percent_women_Q', pred.fun = pred)

# One panel: the mean prediction at each % women value drawn as a black line,
# with a rug showing where the observed % women values lie.
pdp.panel <- function(curves, panel.title) {
  ggplot() +
    stat_summary(
      data = curves, aes(percent_women_Q, yhat),
      fun.y = mean, geom = 'line', col = 'black', size = 1
    ) +
    theme_minimal() +
    xlab('% Women in Legislature') +
    ylab('Predicted Value') +
    ggtitle(panel.title) +
    geom_rug(data = wdi.slim, aes(x = percent_women_Q))
}

p1 <- pdp.panel(pdp.edu, 'Education')
p2 <- pdp.panel(pdp.health, 'Health')
p3 <- pdp.panel(pdp.mil, 'Defense')

# Three panels side by side in a single PDF.
pdf(file = 'pctwomen_3expenditures_PDP.pdf', width = 8, height = 3.5)
grid.arrange(p1, p2, p3, ncol = 3)
dev.off()
# ------------------------------------------------------------------------------------------ #


# ------------------------- CREATE PDPs w/ INTERACTIONS ---------------------------------------- #

# Interaction w/ Democracy: --------
# For each outcome: predictions over a two-way grid of % women and polity
# score, averaged within each grid cell, then drawn as a heat map.
pdp.edu.inter <- partial(rf.edu, pred.var = c('percent_women_Q', 'polity2'), pred.fun = pred, chull = TRUE)
agg.pdp.edu.inter <- aggregate(
  pdp.edu.inter,
  by = list(pdp.edu.inter$percent_women_Q, pdp.edu.inter$polity2),
  FUN = mean
)
g1 <- ggplot(agg.pdp.edu.inter, aes(x = percent_women_Q, y = polity2)) +
  geom_tile(aes(x = percent_women_Q, y = polity2, fill = yhat)) +
  scale_fill_gradientn("Expenditures\n(% of GDP)", colours = c("#0000FFFF", "#FF0000FF")) +
  theme_minimal() +
  xlab('% Women in Legislature') +
  ylab('Polity Score') +
  ggtitle('Education') +
  theme(legend.title = element_text(size = 5), legend.text = element_text(size = 5), legend.position = 'bottom') +
  guides(color = guide_legend(override.aes = list(size = 0.3)))

pdp.health.inter <- partial(rf.health, pred.var = c('percent_women_Q', 'polity2'), pred.fun = pred, chull = TRUE)
agg.pdp.health.inter <- aggregate(
  pdp.health.inter,
  by = list(pdp.health.inter$percent_women_Q, pdp.health.inter$polity2),
  FUN = mean
)
g2 <- ggplot(agg.pdp.health.inter, aes(x = percent_women_Q, y = polity2)) +
  geom_tile(aes(x = percent_women_Q, y = polity2, fill = yhat)) +
  scale_fill_gradientn("Expenditures\n(% of GDP)", colours = c("#0000FFFF", "#FF0000FF")) +
  theme_minimal() +
  xlab('% Women in Legislature') +
  ylab('Polity Score') +
  ggtitle('Health') +
  theme(legend.title = element_text(size = 5), legend.text = element_text(size = 5), legend.position = 'bottom') +
  guides(color = guide_legend(override.aes = list(size = 0.3)))

pdp.mil.inter <- partial(rf.mil, pred.var = c('percent_women_Q', 'polity2'), pred.fun = pred, chull = TRUE)
agg.pdp.mil.inter <- aggregate(
  pdp.mil.inter,
  by = list(pdp.mil.inter$percent_women_Q, pdp.mil.inter$polity2),
  FUN = mean
)
g3 <- ggplot(agg.pdp.mil.inter, aes(x = percent_women_Q, y = polity2)) +
  geom_tile(aes(x = percent_women_Q, y = polity2, fill = yhat)) +
  theme_minimal() +
  xlab('% Women in Legislature') +
  ylab('Polity Score') +
  ggtitle('Defense') +
  scale_fill_gradientn("Expenditures\n(% of GDP)", colours = c("#0000FFFF", "#FF0000FF")) +
  theme(legend.title = element_text(size = 5), legend.text = element_text(size = 5), legend.position = 'bottom') +
  guides(barwidth = .32)

# plot:
# ggplot objects are only drawn when printed, and auto-printing does not occur
# under source(); print explicitly so the PDFs are never written empty.
pdf(file = 'pctwomen_democracyinteraction_edu.pdf', width = 4, height = 4)
print(g1)
dev.off()
pdf(file = 'pctwomen_democracyinteraction_health.pdf', width = 4, height = 4)
print(g2)
dev.off()
pdf(file = 'pctwomen_democracyinteraction_mil.pdf', width = 4, height = 4)
print(g3)
dev.off()

# Interaction with year ----------------------------
# Same heat-map construction as for democracy, with year on the vertical axis.
pdp.edu.inter <- partial(rf.edu, pred.var = c('percent_women_Q', 'year'), pred.fun = pred, chull = TRUE)
agg.pdp.edu.inter <- aggregate(
  pdp.edu.inter,
  by = list(pdp.edu.inter$percent_women_Q, pdp.edu.inter$year),
  FUN = mean
)
g1 <- ggplot(agg.pdp.edu.inter, aes(x = percent_women_Q, y = year)) +
  geom_tile(aes(x = percent_women_Q, y = year, fill = yhat)) +
  scale_fill_gradientn("Expenditures\n(% of GDP)", colours = c("#0000FFFF", "#FF0000FF")) +
  theme_minimal() +
  xlab('% Women in Legislature') +
  ylab('Year') +
  ggtitle('Education') +
  theme(legend.title = element_text(size = 5), legend.text = element_text(size = 5), legend.position = 'bottom') +
  guides(color = guide_legend(override.aes = list(size = 0.3)))
plot(g1) # preview on the currently active graphics device

pdp.health.inter <- partial(rf.health, pred.var = c('percent_women_Q', 'year'), pred.fun = pred, chull = TRUE)
agg.pdp.health.inter <- aggregate(
  pdp.health.inter,
  by = list(pdp.health.inter$percent_women_Q, pdp.health.inter$year),
  FUN = mean
)
g2 <- ggplot(agg.pdp.health.inter, aes(x = percent_women_Q, y = year)) +
  geom_tile(aes(x = percent_women_Q, y = year, fill = yhat)) +
  scale_fill_gradientn("Expenditures\n(% of GDP)", colours = c("#0000FFFF", "#FF0000FF")) +
  theme_minimal() +
  xlab('% Women in Legislature') +
  ylab('Year') +
  ggtitle('Health') +
  theme(legend.title = element_text(size = 5), legend.text = element_text(size = 5), legend.position = 'bottom') +
  guides(color = guide_legend(override.aes = list(size = 0.3)))

pdp.mil.inter <- partial(rf.mil, pred.var = c('percent_women_Q', 'year'), pred.fun = pred, chull = TRUE)
agg.pdp.mil.inter <- aggregate(
  pdp.mil.inter,
  by = list(pdp.mil.inter$percent_women_Q, pdp.mil.inter$year),
  FUN = mean
)
g3 <- ggplot(agg.pdp.mil.inter, aes(x = percent_women_Q, y = year)) +
  geom_tile(aes(x = percent_women_Q, y = year, fill = yhat)) +
  theme_minimal() +
  xlab('% Women in Legislature') +
  ylab('Year') +
  ggtitle('Defense') +
  scale_fill_gradientn("Expenditures\n(% of GDP)", colours = c("#0000FFFF", "#FF0000FF")) +
  theme(legend.title = element_text(size = 5), legend.text = element_text(size = 5), legend.position = 'bottom') +
  guides(barwidth = .32)

# plot
# ggplot objects are only drawn when printed, and auto-printing does not occur
# under source(); print explicitly so the PDFs are never written empty.
pdf(file = 'pctwomen_yearinteraction_edu.pdf', width = 4, height = 4)
print(g1)
dev.off()
pdf(file = 'pctwomen_yearinteraction_health.pdf', width = 4, height = 4)
print(g2)
dev.off()
pdf(file = 'pctwomen_yearinteraction_mil.pdf', width = 4, height = 4)
print(g3)
dev.off()

# Across implementedquota (dichotomous) ------------------
# Note: Dashed = implemented quota, solid = no quota

# One panel: predicted value against % women, one line per quota status,
# distinguished by line type; the legend is suppressed.
quota.panel <- function(cell.means, panel.title) {
  ggplot(
    cell.means,
    aes(x = percent_women_Q, y = yhat, group = factor(implementedquota))
  ) +
    geom_line(aes(linetype = factor(implementedquota))) +
    xlab('% Women in Legislature') +
    ylab('Predicted Value') +
    ggtitle(panel.title) +
    theme_minimal() +
    theme(legend.position = "none")
}

# Education: grid predictions, quota factor converted back to numeric, then
# averaged within each (% women, quota) cell.
pdp.edu.inter <- partial(rf.edu, pred.var = c('percent_women_Q', 'implementedquota'), pred.fun = pred, chull = TRUE)
pdp.edu.inter[["implementedquota"]] <- as.numeric(as.character(pdp.edu.inter[["implementedquota"]]))
agg.pdp.edu.inter <- aggregate(
  pdp.edu.inter,
  by = list(pdp.edu.inter[["percent_women_Q"]], pdp.edu.inter[["implementedquota"]]),
  FUN = mean
)
p1 <- quota.panel(agg.pdp.edu.inter, 'Education')

# Health
pdp.health.inter <- partial(rf.health, pred.var = c('percent_women_Q', 'implementedquota'), pred.fun = pred, chull = TRUE)
pdp.health.inter[["implementedquota"]] <- as.numeric(as.character(pdp.health.inter[["implementedquota"]]))
agg.pdp.health.inter <- aggregate(
  pdp.health.inter,
  by = list(pdp.health.inter[["percent_women_Q"]], pdp.health.inter[["implementedquota"]]),
  FUN = mean
)
p2 <- quota.panel(agg.pdp.health.inter, 'Health')

# Defense
pdp.mil.inter <- partial(rf.mil, pred.var = c('percent_women_Q', 'implementedquota'), pred.fun = pred, chull = TRUE)
pdp.mil.inter[["implementedquota"]] <- as.numeric(as.character(pdp.mil.inter[["implementedquota"]]))
agg.pdp.mil.inter <- aggregate(
  pdp.mil.inter,
  by = list(pdp.mil.inter[["percent_women_Q"]], pdp.mil.inter[["implementedquota"]]),
  FUN = mean
)
p3 <- quota.panel(agg.pdp.mil.inter, 'Defense')

# Three panels side by side in a single PDF.
pdf(file = 'pctwomen_3expenditures_implementedquotainteraction.pdf', width = 8, height = 3.5)
grid.arrange(p1, p2, p3, ncol = 3)
dev.off()
# ------------------------------------------------------------------------------------------ #


# ------------------------- CREATE ICE PLOTS ------------------------------------------- #
# One panel: a faint line per yhat.id (the individual curves), their mean
# overlaid in red, and a rug of the observed % women values.
ice.panel <- function(curves, panel.title) {
  ggplot() +
    geom_line(
      data = curves,
      aes(x = percent_women_Q, y = yhat, group = yhat.id),
      alpha = 0.2
    ) +
    stat_summary(
      data = curves, aes(percent_women_Q, yhat),
      fun.y = mean, geom = 'line', col = 'red', size = 1.5
    ) +
    theme_minimal() +
    xlab('% Women in Legislature') +
    ylab('Predicted Value') +
    ggtitle(panel.title) +
    geom_rug(data = wdi.slim, aes(x = percent_women_Q))
}

p1 <- ice.panel(pdp.edu, 'Education')
p2 <- ice.panel(pdp.health, 'Health')
p3 <- ice.panel(pdp.mil, 'Defense')

# Three panels side by side in a single PDF.
pdf(file = 'pctwomen_3expenditures_ICE.pdf', width = 8, height = 3.5)
grid.arrange(p1, p2, p3, ncol = 3)
dev.off()
# ------------------------------------------------------------------------------------------ #



# ---------------------------- CREATE DALEX EXPLAINERS ---------------------------------- #
# exclude the DVs (needed or else plots post-RF screws up)
# Columns 1-3 of wdi.slim are the three outcomes; everything else is a predictor.
predictors <- as.data.frame(wdi.slim[, -c(1:3)]) # just predictors

# One data set per outcome: that outcome in the first column, then the
# predictors. data.frame() names the first column after the expression, e.g.
# `wdi.slim.percent_education`, which is the name used in the formulas below.
just.edu.data <- data.frame(wdi.slim$percent_education, predictors)
just.health.data <- data.frame(wdi.slim$percent_health, predictors)
just.mil.data <- data.frame(wdi.slim$percent_military, predictors)

# re-run RFs, omitting the other DVs:
# NOTE: randomForest() spells the terminal-node-size argument `nodesize`.
# Supplying it as `node.size` matches no formal argument; the value is absorbed
# by `...` and ignored, so the package default for regression (5) is used
# instead of `node.size`. With the correct spelling the configured value takes
# effect, which means fitted forests differ from runs that used the default.
set.seed(590381)
rf.edu <- randomForest(
  wdi.slim.percent_education ~ .,
  data = just.edu.data, mtry = mtry, nodesize = node.size,
  importance = TRUE, ntree = n.trees
)

rf.health <- randomForest(
  wdi.slim.percent_health ~ .,
  data = just.health.data, mtry = mtry, nodesize = node.size,
  importance = TRUE, ntree = n.trees
)

rf.mil <- randomForest(
  wdi.slim.percent_military ~ .,
  data = just.mil.data, mtry = mtry, nodesize = node.size,
  importance = TRUE, ntree = n.trees
)

# Wrap each forest with its predictor data and observed outcome for DALEX.
explainer.rf.edu <- DALEX::explain(model = rf.edu, data = predictors, y = just.edu.data$wdi.slim.percent_education)

explainer.rf.health <- DALEX::explain(model = rf.health, data = predictors, y = just.health.data$wdi.slim.percent_health)

explainer.rf.mil <- DALEX::explain(model = rf.mil, data = predictors, y = just.mil.data$wdi.slim.percent_military)
# ------------------------------------------------------------------------------------------ #


# ---------------------------- CREATE PLOTS WITH DALEX ---------------------------------- #
# first get local dependence predictions:
ld.edu <- model_profile(explainer = explainer.rf.edu, type = 'conditional', variables  = c('percent_women_Q'))
# and ALE;
ale.edu <- model_profile(explainer = explainer.rf.edu, type = 'accumulated', variables  = c('percent_women_Q'))
# create partial dependence plot 
pdp.edu <- model_profile(explainer = explainer.rf.edu, type = 'partial', variables  = c('percent_women_Q'))
pdp.edu$agr_profiles$`_label_` = "partial dependence" # labels
ld.edu$agr_profiles$`_label_` = "local dependence"
ale.edu$agr_profiles$`_label_` = "accumulated local effect"

ld.health <- model_profile(explainer = explainer.rf.health, type = 'conditional', variables  = c('percent_women_Q'))
# and ALE;
ale.health <- model_profile(explainer = explainer.rf.health, type = 'accumulated', variables  = c('percent_women_Q'))
# create partial dependence plot 
pdp.health <- model_profile(explainer = explainer.rf.health, type = 'partial', variables  = c('percent_women_Q'))
pdp.health$agr_profiles$`_label_` = "partial dependence" # labels
ld.health$agr_profiles$`_label_` = "local dependence"
ale.health$agr_profiles$`_label_` = "accumulated local effect"

ld.mil <- model_profile(explainer = explainer.rf.mil, type = 'conditional', variables  = c('percent_women_Q'))
# and ALE;
ale.mil <- model_profile(explainer = explainer.rf.mil, type = 'accumulated', variables  = c('percent_women_Q'))
# create partial dependence plot 
pdp.mil <- model_profile(explainer = explainer.rf.mil, type = 'partial', variables  = c('percent_women_Q'))
pdp.mil$agr_profiles$`_label_` = "partial dependence" # labels
ld.mil$agr_profiles$`_label_` = "local dependence"
ale.mil$agr_profiles$`_label_` = "accumulated local effect"

# setup for ggplot
masterp.edu <- data.frame(pdp.edu$agr_profiles$`_x_`, pdp.edu$agr_profiles$`_yhat_`, # pdp
                          ld.edu$agr_profiles$`_x_`, ld.edu$agr_profiles$`_yhat_`, # ld
                          ale.edu$agr_profiles$`_x_`, ale.edu$agr_profiles$`_yhat_`) # ale
assertthat::are_equal(masterp.edu$pdp.edu.agr_profiles.._x_., masterp.edu$ld.edu.agr_profiles.._x_.) # make sure x's match
assertthat::are_equal(masterp.edu$pdp.edu.agr_profiles.._x_., masterp.edu$ale.edu.agr_profiles.._x_.) # make sure x's match
# for edu:
pdf(file= 'alternative_dependence_edu.pdf') 
ggplot(aes(x = pdp.edu.agr_profiles.._x_.), data = masterp.edu) +
  geom_line(aes(y = pdp.edu.agr_profiles.._yhat_.), color = '#a1dab4', size = 1.2, linetype = 'solid') +
  geom_line(aes(y = ld.edu.agr_profiles.._yhat_.), color = '#41b6c4', size = 1.2, linetype = 'longdash') + 
  geom_line(aes(y = ale.edu.agr_profiles.._yhat_.), color = '#2c7fb8', size = 1.2, linetype = 'dotted') +
  ylab('Predicted Value') + ggtitle('Education') +
  xlab('% of Women Legislators') + theme_minimal()
dev.off()

# now health
# Same layout as the education panel: x grid and predictions of the PDP,
# local-dependence and ALE profiles side by side. Arguments stay unnamed so
# data.frame() generates the column names used below.
masterp.health <- data.frame(pdp.health$agr_profiles$`_x_`, pdp.health$agr_profiles$`_yhat_`, # pdp
                             ld.health$agr_profiles$`_x_`, ld.health$agr_profiles$`_yhat_`, # ld
                             ale.health$agr_profiles$`_x_`, ale.health$agr_profiles$`_yhat_`) # ale
# All three curves are drawn against the PDP x grid, so stop if the grids
# differ (the same safeguard the education panel is meant to have).
stopifnot(
  isTRUE(all.equal(masterp.health$pdp.health.agr_profiles.._x_., masterp.health$ld.health.agr_profiles.._x_.)),
  isTRUE(all.equal(masterp.health$pdp.health.agr_profiles.._x_., masterp.health$ale.health.agr_profiles.._x_.))
)
pdf(file = 'alternative_dependence_health.pdf')
# print() explicitly: a ggplot object is only drawn when printed, and
# auto-printing does not happen when the script is run through source().
print(
  ggplot(aes(x = pdp.health.agr_profiles.._x_.), data = masterp.health) +
    geom_line(aes(y = pdp.health.agr_profiles.._yhat_.), color = '#a1dab4', size = 1.2, linetype = 'solid') +    # PDP
    geom_line(aes(y = ld.health.agr_profiles.._yhat_.), color = '#41b6c4', size = 1.2, linetype = 'longdash') +  # local dependence
    geom_line(aes(y = ale.health.agr_profiles.._yhat_.), color = '#2c7fb8', size = 1.2, linetype = 'dotted') +   # ALE
    ylab('Predicted Value') + ggtitle('Health') +
    xlab('% of Women Legislators') + theme_minimal()
)
dev.off()

# now defense
# Same layout as the education panel: x grid and predictions of the PDP,
# local-dependence and ALE profiles side by side. Arguments stay unnamed so
# data.frame() generates the column names used below.
masterp.mil <- data.frame(pdp.mil$agr_profiles$`_x_`, pdp.mil$agr_profiles$`_yhat_`, # pdp
                          ld.mil$agr_profiles$`_x_`, ld.mil$agr_profiles$`_yhat_`, # ld
                          ale.mil$agr_profiles$`_x_`, ale.mil$agr_profiles$`_yhat_`) # ale
# All three curves are drawn against the PDP x grid, so stop if the grids
# differ (the same safeguard the education panel is meant to have).
stopifnot(
  isTRUE(all.equal(masterp.mil$pdp.mil.agr_profiles.._x_., masterp.mil$ld.mil.agr_profiles.._x_.)),
  isTRUE(all.equal(masterp.mil$pdp.mil.agr_profiles.._x_., masterp.mil$ale.mil.agr_profiles.._x_.))
)
pdf(file = 'alternative_dependence_mil.pdf')
# print() explicitly: a ggplot object is only drawn when printed, and
# auto-printing does not happen when the script is run through source().
print(
  ggplot(aes(x = pdp.mil.agr_profiles.._x_.), data = masterp.mil) +
    geom_line(aes(y = pdp.mil.agr_profiles.._yhat_.), color = '#a1dab4', size = 1.2, linetype = 'solid') +    # PDP
    geom_line(aes(y = ld.mil.agr_profiles.._yhat_.), color = '#41b6c4', size = 1.2, linetype = 'longdash') +  # local dependence
    geom_line(aes(y = ale.mil.agr_profiles.._yhat_.), color = '#2c7fb8', size = 1.2, linetype = 'dotted') +   # ALE
    ylab('Predicted Value') + ggtitle('Defense') +
    xlab('% of Women Legislators') + theme_minimal()
)
dev.off()
# ------------------------------------------------------------------------------------------ #



# ------------------------- ALTERNATIVE ESTIMATORS ---------------------------- #
set.seed(590381)
# Create a master "all else equal" profile by holding every control at its
# mean (implementedquota is held at its mode instead). The GAM, polynomial and
# piecewise models all use this profile when generating predictions.

# Return the most frequent value of `v`. When several values tie for the
# highest count, the one that appears first in `v` is returned.
getmode <- function(v) {
  distinct.values <- unique(v)
  frequency <- tabulate(match(v, distinct.values))
  distinct.values[which.max(frequency)]
}

# Prediction grid for the parametric models: percent_women_Q runs over its
# sorted unique observed values, implementedquota is fixed at its mode, and
# every other covariate is fixed at its sample mean.
counterfactual <- local({
  # Covariates held at their mean, in the order they follow implementedquota.
  mean.held <- c(
    "defactothreshold", "quotastrength1", "quotastrength2", "quotashock",
    "year", "Agedependencyratioofworki", "Agricultureforestryandfishi",
    "Birthratecrudeper1000peo", "Employmenttopopulationratio",
    "Fertilityratetotalbirthspe", "Foreigndirectinvestmentneti",
    "GDPgrowthannualNYGDPMK", "GDPpercapitaconstant2010US",
    "Importsofgoodsandservices", "InflationGDPdeflatorannual",
    "Laborforceparticipationrate", "Laborforcefemaleoftotal", "P",
    "Lifeexpectancyatbirthfemale", "Lifeexpectancyatbirthmale",
    "Lifetimeriskofmaternaldeath", "Maternalmortalityratiomodele",
    "PopulationgrowthannualSP", "Populationdensitypeoplepers",
    "PopulationtotalSPPOPTOTL", "Populationfemaleoftotal",
    "Prevalenceofanemiaamongnonp", "Ruralpopulationoftotalpop",
    "Schoolenrollmentprimarygr", "TradeofGDPNETRDGNFSZS",
    "Unemploymenttotaloftotal", "Unemploymentmaleofmalela"
  )

  # Named list holding mean(wdi.slim[[v]]) for each requested column.
  at.mean <- function(vars) {
    lapply(setNames(vars, vars), function(v) mean(wdi.slim[[v]]))
  }

  # Column order: % women, polity2, implementedquota, then the remaining
  # controls. Scalars are recycled by data.frame() to the grid length.
  do.call(data.frame, c(
    list(percent_women_Q = sort(unique(wdi.slim$percent_women_Q))),
    at.mean("polity2"),
    list(implementedquota = getmode(wdi.slim$implementedquota)),
    at.mean(mean.held)
  ))
})


# ------------- CUBIC POLYNOMIAL --------------------
# OLS with a cubic polynomial in percent_women_Q plus all controls entered
# linearly; one model per spending outcome. Predictions are then generated
# over the counterfactual grid. The three formulas differ only in the response.
# se.fit is spelled out (rather than se=F) to avoid partial argument matching
# and the reassignable shorthand F.
poly.edu <- lm(
  percent_education ~ poly(percent_women_Q, 3) + polity2 + implementedquota +
    defactothreshold + quotastrength1 + quotastrength2 + quotashock + year +
    Agedependencyratioofworki + Agricultureforestryandfishi +
    Birthratecrudeper1000peo + Employmenttopopulationratio +
    Fertilityratetotalbirthspe + Foreigndirectinvestmentneti +
    GDPgrowthannualNYGDPMK + GDPpercapitaconstant2010US +
    Importsofgoodsandservices + InflationGDPdeflatorannual +
    Laborforceparticipationrate + Laborforcefemaleoftotal + P +
    Lifeexpectancyatbirthfemale + Lifeexpectancyatbirthmale +
    Lifetimeriskofmaternaldeath + Maternalmortalityratiomodele +
    PopulationgrowthannualSP + Populationdensitypeoplepers +
    PopulationtotalSPPOPTOTL + Populationfemaleoftotal +
    Prevalenceofanemiaamongnonp + Ruralpopulationoftotalpop +
    Schoolenrollmentprimarygr + TradeofGDPNETRDGNFSZS +
    Unemploymenttotaloftotal + Unemploymentmaleofmalela,
  data = wdi.slim
)
predict.poly.edu <- predict(poly.edu, newdata = counterfactual, type = 'response', se.fit = FALSE)

poly.health <- lm(
  percent_health ~ poly(percent_women_Q, 3) + polity2 + implementedquota +
    defactothreshold + quotastrength1 + quotastrength2 + quotashock + year +
    Agedependencyratioofworki + Agricultureforestryandfishi +
    Birthratecrudeper1000peo + Employmenttopopulationratio +
    Fertilityratetotalbirthspe + Foreigndirectinvestmentneti +
    GDPgrowthannualNYGDPMK + GDPpercapitaconstant2010US +
    Importsofgoodsandservices + InflationGDPdeflatorannual +
    Laborforceparticipationrate + Laborforcefemaleoftotal + P +
    Lifeexpectancyatbirthfemale + Lifeexpectancyatbirthmale +
    Lifetimeriskofmaternaldeath + Maternalmortalityratiomodele +
    PopulationgrowthannualSP + Populationdensitypeoplepers +
    PopulationtotalSPPOPTOTL + Populationfemaleoftotal +
    Prevalenceofanemiaamongnonp + Ruralpopulationoftotalpop +
    Schoolenrollmentprimarygr + TradeofGDPNETRDGNFSZS +
    Unemploymenttotaloftotal + Unemploymentmaleofmalela,
  data = wdi.slim
)
predict.poly.health <- predict(poly.health, newdata = counterfactual, type = 'response', se.fit = FALSE)

poly.mil <- lm(
  percent_military ~ poly(percent_women_Q, 3) + polity2 + implementedquota +
    defactothreshold + quotastrength1 + quotastrength2 + quotashock + year +
    Agedependencyratioofworki + Agricultureforestryandfishi +
    Birthratecrudeper1000peo + Employmenttopopulationratio +
    Fertilityratetotalbirthspe + Foreigndirectinvestmentneti +
    GDPgrowthannualNYGDPMK + GDPpercapitaconstant2010US +
    Importsofgoodsandservices + InflationGDPdeflatorannual +
    Laborforceparticipationrate + Laborforcefemaleoftotal + P +
    Lifeexpectancyatbirthfemale + Lifeexpectancyatbirthmale +
    Lifetimeriskofmaternaldeath + Maternalmortalityratiomodele +
    PopulationgrowthannualSP + Populationdensitypeoplepers +
    PopulationtotalSPPOPTOTL + Populationfemaleoftotal +
    Prevalenceofanemiaamongnonp + Ruralpopulationoftotalpop +
    Schoolenrollmentprimarygr + TradeofGDPNETRDGNFSZS +
    Unemploymenttotaloftotal + Unemploymentmaleofmalela,
  data = wdi.slim
)
predict.poly.mil <- predict(poly.mil, newdata = counterfactual, type = 'response', se.fit = FALSE)


# ------------- PIECEWISE REGRESSION ----------------
# Segmented (piecewise-linear) regression with two breakpoints (npsi = 2) in
# percent_women_Q, built on a linear model with the controls; one model per
# spending outcome, each predicted over the counterfactual grid.
# Note that a singular matrix error results unless PopulationtotalSPPOPTOTL is
# removed, so that control is omitted from all three formulas.
# The seed is reset before every fit so each model can be re-run on its own
# (presumably because the breakpoint search draws random numbers; see
# seg.control).
# se.fit is spelled out (rather than se=F) to avoid partial argument matching
# and the reassignable shorthand F.
set.seed(590381)
pr.edu <- segmented(
  lm(
    percent_education ~ percent_women_Q + polity2 + implementedquota +
      defactothreshold + quotastrength1 + quotastrength2 + quotashock + year +
      Agedependencyratioofworki + Agricultureforestryandfishi +
      Birthratecrudeper1000peo + Employmenttopopulationratio +
      Fertilityratetotalbirthspe + Foreigndirectinvestmentneti +
      GDPgrowthannualNYGDPMK + GDPpercapitaconstant2010US +
      Importsofgoodsandservices + InflationGDPdeflatorannual +
      Laborforceparticipationrate + Laborforcefemaleoftotal + P +
      Lifeexpectancyatbirthfemale + Lifeexpectancyatbirthmale +
      Lifetimeriskofmaternaldeath + Maternalmortalityratiomodele +
      PopulationgrowthannualSP + Populationdensitypeoplepers +
      Populationfemaleoftotal + Prevalenceofanemiaamongnonp +
      Ruralpopulationoftotalpop + Schoolenrollmentprimarygr +
      TradeofGDPNETRDGNFSZS + Unemploymenttotaloftotal +
      Unemploymentmaleofmalela,
    data = wdi.slim
  ),
  seg.Z = ~ percent_women_Q, npsi = 2, control = seg.control(display = FALSE)
)
predict.pr.edu <- predict(pr.edu, newdata = counterfactual, type = 'response', se.fit = FALSE)

set.seed(590381)
pr.health <- segmented(
  lm(
    percent_health ~ percent_women_Q + polity2 + implementedquota +
      defactothreshold + quotastrength1 + quotastrength2 + quotashock + year +
      Agedependencyratioofworki + Agricultureforestryandfishi +
      Birthratecrudeper1000peo + Employmenttopopulationratio +
      Fertilityratetotalbirthspe + Foreigndirectinvestmentneti +
      GDPgrowthannualNYGDPMK + GDPpercapitaconstant2010US +
      Importsofgoodsandservices + InflationGDPdeflatorannual +
      Laborforceparticipationrate + Laborforcefemaleoftotal + P +
      Lifeexpectancyatbirthfemale + Lifeexpectancyatbirthmale +
      Lifetimeriskofmaternaldeath + Maternalmortalityratiomodele +
      PopulationgrowthannualSP + Populationdensitypeoplepers +
      Populationfemaleoftotal + Prevalenceofanemiaamongnonp +
      Ruralpopulationoftotalpop + Schoolenrollmentprimarygr +
      TradeofGDPNETRDGNFSZS + Unemploymenttotaloftotal +
      Unemploymentmaleofmalela,
    data = wdi.slim
  ),
  seg.Z = ~ percent_women_Q, npsi = 2, control = seg.control(display = FALSE)
)
predict.pr.health <- predict(pr.health, newdata = counterfactual, type = 'response', se.fit = FALSE)

set.seed(590381)
pr.mil <- segmented(
  lm(
    percent_military ~ percent_women_Q + polity2 + implementedquota +
      defactothreshold + quotastrength1 + quotastrength2 + quotashock + year +
      Agedependencyratioofworki + Agricultureforestryandfishi +
      Birthratecrudeper1000peo + Employmenttopopulationratio +
      Fertilityratetotalbirthspe + Foreigndirectinvestmentneti +
      GDPgrowthannualNYGDPMK + GDPpercapitaconstant2010US +
      Importsofgoodsandservices + InflationGDPdeflatorannual +
      Laborforceparticipationrate + Laborforcefemaleoftotal + P +
      Lifeexpectancyatbirthfemale + Lifeexpectancyatbirthmale +
      Lifetimeriskofmaternaldeath + Maternalmortalityratiomodele +
      PopulationgrowthannualSP + Populationdensitypeoplepers +
      Populationfemaleoftotal + Prevalenceofanemiaamongnonp +
      Ruralpopulationoftotalpop + Schoolenrollmentprimarygr +
      TradeofGDPNETRDGNFSZS + Unemploymenttotaloftotal +
      Unemploymentmaleofmalela,
    data = wdi.slim
  ),
  seg.Z = ~ percent_women_Q, npsi = 2, control = seg.control(display = FALSE)
)
predict.pr.mil <- predict(pr.mil, newdata = counterfactual, type = 'response', se.fit = FALSE)

# ------------- GAM ---------------------------------
# For each outcome, % women and that outcome's most important predictors enter
# as smooth terms, s(); every other control enters linearly. Predictions are
# generated over the counterfactual grid.
# se.fit is spelled out (rather than se=F) to avoid partial argument matching
# and the reassignable shorthand F.

# DV = education. Smoothed predictors (10): pop density, total pop, % female
# pop, % female lab force, polity, anemia prevalence (% women), GDP per capita,
# imports, % rural pop, female lab force participation rate
# -> Populationdensitypeoplepers, PopulationtotalSPPOPTOTL,
#    Populationfemaleoftotal, Laborforcefemaleoftotal, polity2,
#    Prevalenceofanemiaamongnonp, GDPpercapitaconstant2010US,
#    Importsofgoodsandservices, Ruralpopulationoftotalpop,
#    Laborforceparticipationrate
set.seed(590381)
gam.edu <- gam(
  percent_education ~ s(percent_women_Q) + s(polity2) + implementedquota +
    defactothreshold + quotastrength1 + quotastrength2 + quotashock + year +
    Agedependencyratioofworki + Agricultureforestryandfishi +
    Birthratecrudeper1000peo + Employmenttopopulationratio +
    Fertilityratetotalbirthspe + Foreigndirectinvestmentneti +
    GDPgrowthannualNYGDPMK + s(GDPpercapitaconstant2010US) +
    s(Importsofgoodsandservices) + InflationGDPdeflatorannual +
    s(Laborforceparticipationrate) + s(Laborforcefemaleoftotal) + P +
    Lifeexpectancyatbirthfemale + Lifeexpectancyatbirthmale +
    Lifetimeriskofmaternaldeath + Maternalmortalityratiomodele +
    PopulationgrowthannualSP + s(Populationdensitypeoplepers) +
    s(PopulationtotalSPPOPTOTL) + s(Populationfemaleoftotal) +
    s(Prevalenceofanemiaamongnonp) + s(Ruralpopulationoftotalpop) +
    Schoolenrollmentprimarygr + TradeofGDPNETRDGNFSZS +
    Unemploymenttotaloftotal + Unemploymentmaleofmalela,
  data = wdi.slim
)
predict.gam.edu <- predict(gam.edu, newdata = counterfactual, type = 'response', se.fit = FALSE)

# DV = health. Smoothed predictors (10): % female pop, total pop, pop density,
# % rural pop, trade, imports, female life expectancy, year, anemia prevalence
# (% women), % school enrollment
# -> Populationfemaleoftotal, PopulationtotalSPPOPTOTL,
#    Populationdensitypeoplepers, Ruralpopulationoftotalpop,
#    TradeofGDPNETRDGNFSZS, Importsofgoodsandservices,
#    Lifeexpectancyatbirthfemale, year, Prevalenceofanemiaamongnonp,
#    Schoolenrollmentprimarygr
set.seed(590381)
gam.health <- gam(
  percent_health ~ s(percent_women_Q) + polity2 + implementedquota +
    defactothreshold + quotastrength1 + quotastrength2 + quotashock + s(year) +
    Agedependencyratioofworki + Agricultureforestryandfishi +
    Birthratecrudeper1000peo + Employmenttopopulationratio +
    Fertilityratetotalbirthspe + Foreigndirectinvestmentneti +
    GDPgrowthannualNYGDPMK + GDPpercapitaconstant2010US +
    s(Importsofgoodsandservices) + InflationGDPdeflatorannual +
    Laborforceparticipationrate + Laborforcefemaleoftotal + P +
    s(Lifeexpectancyatbirthfemale) + Lifeexpectancyatbirthmale +
    Lifetimeriskofmaternaldeath + Maternalmortalityratiomodele +
    PopulationgrowthannualSP + s(Populationdensitypeoplepers) +
    s(PopulationtotalSPPOPTOTL) + s(Populationfemaleoftotal) +
    s(Prevalenceofanemiaamongnonp) + s(Ruralpopulationoftotalpop) +
    s(Schoolenrollmentprimarygr) + s(TradeofGDPNETRDGNFSZS) +
    Unemploymenttotaloftotal + Unemploymentmaleofmalela,
  data = wdi.slim
)
predict.gam.health <- predict(gam.health, newdata = counterfactual, type = 'response', se.fit = FALSE)

# DV = military. Smoothed predictors: polity, % rural pop, % female pop, pop
# density, % lab force, quota strength 2, age dep ratio, maternal mortality
# rate, female lab force participation rate, fertility rate, imports
# -> polity2, Ruralpopulationoftotalpop, Populationfemaleoftotal,
#    Populationdensitypeoplepers, Laborforcefemaleoftotal, quotastrength2,
#    Agedependencyratioofworki, Maternalmortalityratiomodele,
#    Laborforceparticipationrate, Fertilityratetotalbirthspe,
#    Importsofgoodsandservices
# NOTE(review): this list (and the formula) has 11 smoothed predictors, not the
# 10 used for education and health -- confirm which was intended.
# Note that quotastrength2 has only 6 unique values, so curbing to k=3
set.seed(590381)
gam.mil <- gam(
  percent_military ~ s(percent_women_Q) + s(polity2) + implementedquota +
    defactothreshold + quotastrength1 + s(quotastrength2, k = 3) + quotashock +
    year + s(Agedependencyratioofworki) + Agricultureforestryandfishi +
    Birthratecrudeper1000peo + Employmenttopopulationratio +
    s(Fertilityratetotalbirthspe) + Foreigndirectinvestmentneti +
    GDPgrowthannualNYGDPMK + GDPpercapitaconstant2010US +
    s(Importsofgoodsandservices) + InflationGDPdeflatorannual +
    s(Laborforceparticipationrate) + s(Laborforcefemaleoftotal) + P +
    Lifeexpectancyatbirthfemale + Lifeexpectancyatbirthmale +
    Lifetimeriskofmaternaldeath + s(Maternalmortalityratiomodele) +
    PopulationgrowthannualSP + s(Populationdensitypeoplepers) +
    PopulationtotalSPPOPTOTL + s(Populationfemaleoftotal) +
    Prevalenceofanemiaamongnonp + s(Ruralpopulationoftotalpop) +
    Schoolenrollmentprimarygr + TradeofGDPNETRDGNFSZS +
    Unemploymenttotaloftotal + Unemploymentmaleofmalela,
  data = wdi.slim
)
predict.gam.mil <- predict(gam.mil, newdata = counterfactual, type = 'response', se.fit = FALSE)

# ------------ NEURAL NET ---------------------------
# One network per spending outcome, each with a single hidden layer of 17
# units. The predictor set is the same as for the other estimators except
# that implementedquota is not included. The seed is reset before every fit.
# plot_partial_dependencies() returns the partial-dependence plot for
# percent_women_Q; its $data is reused when the estimators are combined.
set.seed(590381)
nn.edu <- NeuralNetwork(
  percent_education ~ percent_women_Q + polity2 + defactothreshold +
    quotastrength1 + quotastrength2 + quotashock + year +
    Agedependencyratioofworki + Agricultureforestryandfishi +
    Birthratecrudeper1000peo + Employmenttopopulationratio +
    Fertilityratetotalbirthspe + Foreigndirectinvestmentneti +
    GDPgrowthannualNYGDPMK + GDPpercapitaconstant2010US +
    Importsofgoodsandservices + InflationGDPdeflatorannual +
    Laborforceparticipationrate + Laborforcefemaleoftotal + P +
    Lifeexpectancyatbirthfemale + Lifeexpectancyatbirthmale +
    Lifetimeriskofmaternaldeath + Maternalmortalityratiomodele +
    PopulationgrowthannualSP + Populationdensitypeoplepers +
    PopulationtotalSPPOPTOTL + Populationfemaleoftotal +
    Prevalenceofanemiaamongnonp + Ruralpopulationoftotalpop +
    Schoolenrollmentprimarygr + TradeofGDPNETRDGNFSZS +
    Unemploymenttotaloftotal + Unemploymentmaleofmalela,
  data = wdi.slim,
  layers = c(17),
  scale = TRUE,
  linear.output = TRUE,
  threshold = 0.01,
  algorithm = 'rprop+',
  learningrate = 0.01,
  rep = 1
)
nn.edu.pdp <- plot_partial_dependencies(nn.edu, predictors = 'percent_women_Q')

set.seed(590381)
nn.health <- NeuralNetwork(
  percent_health ~ percent_women_Q + polity2 + defactothreshold +
    quotastrength1 + quotastrength2 + quotashock + year +
    Agedependencyratioofworki + Agricultureforestryandfishi +
    Birthratecrudeper1000peo + Employmenttopopulationratio +
    Fertilityratetotalbirthspe + Foreigndirectinvestmentneti +
    GDPgrowthannualNYGDPMK + GDPpercapitaconstant2010US +
    Importsofgoodsandservices + InflationGDPdeflatorannual +
    Laborforceparticipationrate + Laborforcefemaleoftotal + P +
    Lifeexpectancyatbirthfemale + Lifeexpectancyatbirthmale +
    Lifetimeriskofmaternaldeath + Maternalmortalityratiomodele +
    PopulationgrowthannualSP + Populationdensitypeoplepers +
    PopulationtotalSPPOPTOTL + Populationfemaleoftotal +
    Prevalenceofanemiaamongnonp + Ruralpopulationoftotalpop +
    Schoolenrollmentprimarygr + TradeofGDPNETRDGNFSZS +
    Unemploymenttotaloftotal + Unemploymentmaleofmalela,
  data = wdi.slim,
  layers = c(17),
  scale = TRUE,
  linear.output = TRUE,
  threshold = 0.01,
  algorithm = 'rprop+',
  learningrate = 0.01,
  rep = 1
)
nn.health.pdp <- plot_partial_dependencies(nn.health, predictors = 'percent_women_Q')

set.seed(590381)
nn.mil <- NeuralNetwork(
  percent_military ~ percent_women_Q + polity2 + defactothreshold +
    quotastrength1 + quotastrength2 + quotashock + year +
    Agedependencyratioofworki + Agricultureforestryandfishi +
    Birthratecrudeper1000peo + Employmenttopopulationratio +
    Fertilityratetotalbirthspe + Foreigndirectinvestmentneti +
    GDPgrowthannualNYGDPMK + GDPpercapitaconstant2010US +
    Importsofgoodsandservices + InflationGDPdeflatorannual +
    Laborforceparticipationrate + Laborforcefemaleoftotal + P +
    Lifeexpectancyatbirthfemale + Lifeexpectancyatbirthmale +
    Lifetimeriskofmaternaldeath + Maternalmortalityratiomodele +
    PopulationgrowthannualSP + Populationdensitypeoplepers +
    PopulationtotalSPPOPTOTL + Populationfemaleoftotal +
    Prevalenceofanemiaamongnonp + Ruralpopulationoftotalpop +
    Schoolenrollmentprimarygr + TradeofGDPNETRDGNFSZS +
    Unemploymenttotaloftotal + Unemploymentmaleofmalela,
  data = wdi.slim,
  layers = c(17),
  scale = TRUE,
  linear.output = TRUE,
  threshold = 0.01,
  algorithm = 'rprop+',
  learningrate = 0.01,
  rep = 1
)
nn.mil.pdp <- plot_partial_dependencies(nn.mil, predictors = 'percent_women_Q')

# ------------ JOIN EVERYTHING TOGETHER IN SINGLE PLOTS ---------------------------
# Pull the partial-dependence grid and prediction out of each neural-net plot.
# The column names are written out explicitly; they are the same names
# data.frame() would generate from the unnamed expressions.
nn.predicts.edu <- data.frame(
  nn.edu.pdp.data.percent_women_Q = nn.edu.pdp$data$percent_women_Q,
  nn.edu.pdp.data.yhat = nn.edu.pdp$data$yhat
)
nn.predicts.health <- data.frame(
  nn.health.pdp.data.percent_women_Q = nn.health.pdp$data$percent_women_Q,
  nn.health.pdp.data.yhat = nn.health.pdp$data$yhat
)
nn.predicts.mil <- data.frame(
  nn.mil.pdp.data.percent_women_Q = nn.mil.pdp$data$percent_women_Q,
  nn.mil.pdp.data.yhat = nn.mil.pdp$data$yhat
)

# Start from everything predicted over the counterfactual grid
# (piecewise, GAM and polynomial, for all three outcomes).
predicts <- data.frame(
  counterfactual,
  predict.pr.edu, predict.pr.health, predict.pr.mil,
  predict.gam.edu, predict.gam.health, predict.gam.mil,
  predict.poly.edu, predict.poly.health, predict.poly.mil
)

# Outer-join the neural-net predictions on percent_women_Q, one outcome at a
# time. all = TRUE keeps unmatched rows from both sides, filling the missing
# columns with NA.
predicts <- merge(predicts, nn.predicts.edu,
                  by.x = 'percent_women_Q', by.y = 'nn.edu.pdp.data.percent_women_Q',
                  all = TRUE, no.dups = FALSE)
predicts <- merge(predicts, nn.predicts.health,
                  by.x = 'percent_women_Q', by.y = 'nn.health.pdp.data.percent_women_Q',
                  all = TRUE, no.dups = FALSE)
predicts <- merge(predicts, nn.predicts.mil,
                  by.x = 'percent_women_Q', by.y = 'nn.mil.pdp.data.percent_women_Q',
                  all = TRUE, no.dups = FALSE)

# for edu:
# Overlay the four alternative estimators' predicted education spending
# against % women legislators and write the figure to a 4x4 inch PDF.
pdf(file = 'alternative_estimators_edu.pdf', width = 4, height = 4)
# print() explicitly: a ggplot object is only drawn when printed, and
# auto-printing does not happen when the script is run through source().
print(
  ggplot(aes(x = percent_women_Q), data = predicts) +
    geom_line(aes(y = predict.pr.edu), color = '#a1dab4', size = 1.2, linetype = 'solid') +         # piecewise
    geom_line(aes(y = predict.gam.edu), color = '#41b6c4', size = 1.2, linetype = 'longdash') +     # GAM
    geom_line(aes(y = predict.poly.edu), color = '#2c7fb8', size = 1.2, linetype = 'dotted') +      # cubic polynomial
    geom_line(aes(y = nn.edu.pdp.data.yhat), size = 1.2, color = '#253494', linetype = 'dotdash') + # neural net
    ylab('Predicted Value of Education Expenditures') +
    xlab('% of Women Legislators') + theme_minimal()
)
dev.off()

# for health:
# Overlay the four alternative estimators' predicted health spending
# against % women legislators and write the figure to a 4x4 inch PDF.
pdf(file = 'alternative_estimators_health.pdf', width = 4, height = 4)
# print() explicitly: a ggplot object is only drawn when printed, and
# auto-printing does not happen when the script is run through source().
print(
  ggplot(aes(x = percent_women_Q), data = predicts) +
    geom_line(aes(y = predict.pr.health), color = '#a1dab4', size = 1.2, linetype = 'solid') +         # piecewise
    geom_line(aes(y = predict.gam.health), color = '#41b6c4', size = 1.2, linetype = 'longdash') +     # GAM
    geom_line(aes(y = predict.poly.health), color = '#2c7fb8', size = 1.2, linetype = 'dotted') +      # cubic polynomial
    geom_line(aes(y = nn.health.pdp.data.yhat), size = 1.2, color = '#253494', linetype = 'dotdash') + # neural net
    ylab('Predicted Value of Health Expenditures') +
    xlab('% of Women Legislators') + theme_minimal()
)
dev.off()

# for defense:
# Overlay the four alternative estimators' predicted defense spending
# against % women legislators and write the figure to a 4x4 inch PDF.
pdf(file = 'alternative_estimators_mil.pdf', width = 4, height = 4)
# print() explicitly: a ggplot object is only drawn when printed, and
# auto-printing does not happen when the script is run through source().
print(
  ggplot(aes(x = percent_women_Q), data = predicts) +
    geom_line(aes(y = predict.pr.mil), color = '#a1dab4', size = 1.2, linetype = 'solid') +         # piecewise
    geom_line(aes(y = predict.gam.mil), color = '#41b6c4', size = 1.2, linetype = 'longdash') +     # GAM
    geom_line(aes(y = predict.poly.mil), color = '#2c7fb8', size = 1.2, linetype = 'dotted') +      # cubic polynomial
    # NOTE(review): size = 1.5 here, while the education and health panels
    # draw the neural-net line at 1.2 -- confirm whether this is intended.
    geom_line(aes(y = nn.mil.pdp.data.yhat), size = 1.5, color = '#253494', linetype = 'dotdash') + # neural net
    ylab('Predicted Value of Defense Expenditures') +
    xlab('% of Women Legislators') + theme_minimal()
)
dev.off()
# ---------------------------------------------------------------------------- #

